{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepSphere using SHREC17 dataset\n",
    "## Benchmark against the S2CNN method of Cohen et al. [[1]](http://arxiv.org/abs/1801.10130) and the method of Esteves et al. [[2]](http://arxiv.org/abs/1711.06721)\n",
    "Multi-class classification of 3D objects, using the rotation equivariance property of the network.\n",
    "\n",
    "The 3D objects are projected onto a unit sphere.\n",
    "Cohen et al. and Esteves et al. use an equiangular sampling, while our method uses the HEALPix sampling.\n",
    "\n",
    "Several features are collected:\n",
    "* projection ray length (from the sphere surface to the intersection with the object, in [0, 2])\n",
    "* cos/sin of the angle with the surface normal\n",
    "* same features using the convex hull of the 3D object\n",
    "\n",
    "### HEALPix sampling - TF dataset pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.1 Load libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # index of the GPU to use; set to \"\" to run on CPU\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import healpy as hp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepsphere import models, experiment_helper, plot, utils\n",
    "from deepsphere.data import LabeledDatasetWithNoise, LabeledDataset\n",
    "import hyperparameters\n",
    "\n",
    "from load_shrec import fix_dataset, Shrec17Dataset, Shrec17DatasetCache, Shrec17DatasetTF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.2 Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Nside = 32\n",
    "experiment_type = 'CNN' # 'FCN'\n",
    "ename = '_'+experiment_type\n",
    "datapath = '../../../data/shrec17/' # location of the .obj files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise_dataset = True    # use perturbed dataset (Cohen and Esteves do the same)\n",
    "augmentation = 1        # number of elements per file (1 = no augmentation of dataset)\n",
    "nfeat = 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if datasets are already downloaded but not preprocessed\n",
    "fix = False\n",
    "download = False\n",
    "if fix:\n",
    "    fix_dataset(datapath+'val_perturbed')\n",
    "    fix_dataset(datapath+'test_perturbed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the dataset if `download` is True, preprocess the data, store it in npy files, and load it into a dataset object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_dataset = Shrec17DatasetCache(datapath, 'val', perturbed=noise_dataset, download=download, \n",
    "                                  nside=Nside, nfeat=nfeat, augmentation=1, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_nonrot_dataset = Shrec17DatasetCache(datapath, 'val', perturbed=noise_dataset, download=download, \n",
    "                                         nside=Nside, nfeat=nfeat, experiment='deepsphere_norot', augmentation=1, nfile=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use a TensorFlow dataset object for the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_TFDataset = Shrec17DatasetTF(datapath, 'train', perturbed=noise_dataset, download=download, \n",
    "                                   nside=Nside, nfeat=nfeat, augmentation=augmentation, nfile=None, experiment='deepsphere')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = train_TFDataset.get_tf_dataset(32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test iterating over the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import tensorflow as tf\n",
    "# from tqdm import tqdm\n",
    "\n",
    "# #dataset = tf_dataset_file(datapath, dataset, file_pattern, 32, Nside, augmentation)\n",
    "# data_next = dataset.make_one_shot_iterator().get_next()\n",
    "# config = tf.ConfigProto()\n",
    "# config.gpu_options.allow_growth = True\n",
    "# steps = train_TFDataset.N // 32 + 1\n",
    "# with tf.Session(config=config) as sess:\n",
    "#     sess.run(tf.global_variables_initializer())\n",
    "#     try:\n",
    "#         for i in tqdm(range(steps)):\n",
    "#             out = sess.run(data_next)\n",
    "#     except tf.errors.OutOfRangeError:\n",
    "#         print(\"Done\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the loading time of the two methods (TensorFlow dataset and Python iterator)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import time\n",
    "# import tensorflow as tf\n",
    "\n",
    "# #dataset = tf_dataset_file(datapath, dataset, file_pattern, 32, Nside, augmentation)\n",
    "# t_start = time.time()\n",
    "# data_next = dataset.make_one_shot_iterator().get_next()\n",
    "# config = tf.ConfigProto()\n",
    "# config.gpu_options.allow_growth = True\n",
    "# steps = train_TFDataset.N // 32 + 1\n",
    "# with tf.Session(config=config) as sess:\n",
    "#     sess.run(tf.global_variables_initializer())\n",
    "#     try:\n",
    "#         for i in range(steps):\n",
    "#             out = sess.run(data_next)\n",
    "#     except tf.errors.OutOfRangeError:\n",
    "#         print(\"Done\") # never reached, as the dataset repeats indefinitely\n",
    "# t_end = time.time()\n",
    "# print(str(t_end-t_start)+\" s\")\n",
    "\n",
    "# train_dataset = Shrec17Dataset(datapath, 'train', perturbed=noise_dataset, download=download, \n",
    "#                                 nside=Nside, augmentation=augmentation, nfile=None, load=False)\n",
    "\n",
    "# # t_start = time.time()\n",
    "# # data_iter = train_dataset.iter(32)\n",
    "# # steps = int(train_dataset.N / 32)\n",
    "# # for i in range(steps):\n",
    "# #     next(data_iter)\n",
    "# #     #feed_dict = {self.ph_data: batch_data, self.ph_labels: batch_labels, self.ph_training: True}\n",
    "# # t_end = time.time()\n",
    "# # print(str(t_end-t_start)+\" s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dataset information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nclass = train_TFDataset.nclass\n",
    "num_elem = train_TFDataset.N\n",
    "#ids_train = train_dataset.ids\n",
    "print('number of classes:',nclass,'\\nnumber of elements:',num_elem)#,'\\nfirst id:',ids_train[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 Classification using DeepSphere"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the same Dataset object as in the other DeepSphere experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = 'shrec17_newGraph_{}feat_{}aug_{}sides{}'.format(nfeat, augmentation, Nside, ename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the model with the chosen hyperparameters.\n",
    "For each experiment, a new EXP_NAME is chosen and the new hyperparameters are stored.\n",
    "All the information is available in 'DeepSphere/Shrec17/experiments.md'.\n",
    "The fastest way to reproduce an experiment is to revert to the commit of that experiment, which restores the correct files and notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = hyperparameters.get_params_shrec17(num_elem, EXP_NAME, Nside, nclass, nfeat_in=nfeat, architecture=experiment_type)\n",
    "params[\"tf_dataset\"] = train_TFDataset.get_tf_dataset(params[\"batch_size\"])\n",
    "#params[\"std\"] = [0.001, 0.005, 0.0125, 0.05, 0.15, 0.5]\n",
    "#params[\"full\"] = [True]*6\n",
    "#params[\"extra_loss\"]=True\n",
    "model = models.deepsphere(**params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree('summaries/{}/'.format(EXP_NAME), ignore_errors=True)\n",
    "shutil.rmtree('checkpoints/{}/'.format(EXP_NAME), ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Find a suitable learning rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# backup = params.copy()\n",
    "\n",
    "# params, learning_rate = utils.test_learning_rates(params, train_TFDataset.N, 1e-6, 1e-1, num_epochs=20)\n",
    "\n",
    "# shutil.rmtree('summaries/{}/'.format(params['dir_name']), ignore_errors=True)\n",
    "# shutil.rmtree('checkpoints/{}/'.format(params['dir_name']), ignore_errors=True)\n",
    "\n",
    "# model = models.deepsphere(**params)\n",
    "# _, loss_validation, _, _ = model.fit(train_TFDataset, val_dataset, use_tf_dataset=True, cache=True)\n",
    "\n",
    "# params.update(backup)\n",
    "\n",
    "# plt.semilogx(learning_rate, loss_validation, '.-')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shutil.rmtree('summaries/lr_finder/', ignore_errors=True)\n",
    "# shutil.rmtree('checkpoints/lr_finder/', ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "0.9 seems to be a good learning rate for SGD with the current parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Train Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"the number of parameters in the model is: {:,}\".format(model.get_nbr_var()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "accuracy_validation, loss_validation, loss_training, t_step, t_batch = model.fit(train_TFDataset, val_dataset, use_tf_dataset=True, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.plot_loss(loss_training, loss_validation, t_step, params['eval_frequency'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(val_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(val_nonrot_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids_val = val_dataset.get_ids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "probabilities, _ = model.probs(val_dataset, nclass, cache=True)\n",
    "# if augmentation>1:\n",
    "#     probabilities = probabilities.reshape((-1,augmentation,nclass))\n",
    "#     probabilities = probabilities.mean(axis=1)\n",
    "#     ids_val = ids_val[::augmentation]\n",
    "predictions = np.argmax(probabilities, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# probabilities = model.probs(x_val, nclass)\n",
    "probabilities, _ = model.probs(val_nonrot_dataset, nclass, cache=True)\n",
    "# if augmentation>1:\n",
    "#     probabilities = probabilities.reshape((-1,augmentation,nclass))\n",
    "#     probabilities = probabilities.mean(axis=1)\n",
    "#     ids_val = ids_val[::augmentation]\n",
    "predictions = np.argmax(probabilities, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from load_shrec import shrec_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shrec_output(model.get_descriptor(val_dataset), ids_val, probabilities, datapath, 'results/val_perturbed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for every object, retrieve all objects predicted in the same class, sorted by decreasing relevance\n",
    "os.makedirs(os.path.join(datapath,'results_aug/val_perturbed'), exist_ok=True)\n",
    "for i,_id in enumerate(ids_val):\n",
    "    idfile = os.path.join(datapath,'results_aug/val_perturbed',_id)\n",
    "    # probabilities has shape (batch, nclass); predictions has shape (batch,)\n",
    "    # relevance of object j = probability of its predicted class\n",
    "    retrieved = [(probabilities[j, predictions[j]], ids_val[j]) for j in range(len(ids_val)) if predictions[j] == predictions[i]]\n",
    "    retrieved = sorted(retrieved, reverse=True)\n",
    "    retrieved = [obj_id for _, obj_id in retrieved]\n",
    "    with open(idfile, \"w\") as f:\n",
    "        f.write(\"\\n\".join(retrieved))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NaN appears if the i == j case is removed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 Test network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = Shrec17DatasetCache(datapath, 'test', perturbed=noise_dataset, download=download, \n",
    "                                   nside=Nside, nfeat=nfeat, augmentation=1, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_nonrot_dataset = Shrec17DatasetCache(datapath, 'test', perturbed=noise_dataset, download=download, \n",
    "                                          nside=Nside, nfeat=nfeat, experiment='deepsphere_norot', augmentation=1, nfile=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.evaluate(test_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.evaluate(test_nonrot_dataset, None, cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids_test = test_nonrot_dataset.get_ids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probabilities, _ = model.probs(test_dataset, nclass, cache=True)\n",
    "# if augmentation>1:\n",
    "#     probabilities = probabilities.reshape((-1,augmentation,nclass))\n",
    "#     probabilities = probabilities.mean(axis=1)\n",
    "predictions = np.argmax(probabilities, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probabilities, _ = model.probs(test_nonrot_dataset, nclass, cache=True)\n",
    "# if augmentation>1:\n",
    "#     probabilities = probabilities.reshape((-1,augmentation,nclass))\n",
    "#     probabilities = probabilities.mean(axis=1)\n",
    "predictions = np.argmax(probabilities, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "write to file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for every object, retrieve all objects predicted in the same class, sorted by decreasing relevance\n",
    "os.makedirs(os.path.join(datapath,'results_aug/test_perturbed'), exist_ok=True)\n",
    "for i, _id in enumerate(ids_test):\n",
    "    idfile = os.path.join(datapath,'results_aug/test_perturbed',_id)\n",
    "    # probabilities has shape (batch, nclass); predictions has shape (batch,)\n",
    "    # relevance of object j = probability of its predicted class\n",
    "    retrieved = [(probabilities[j, predictions[j]], ids_test[j]) for j in range(len(ids_test)) if predictions[j] == predictions[i]]\n",
    "    retrieved = sorted(retrieved, reverse=True)\n",
    "    retrieved = [obj_id for _, obj_id in retrieved]\n",
    "    with open(idfile, \"w\") as f:\n",
    "        f.write(\"\\n\".join(retrieved))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shrec_output(model.get_descriptor(test_nonrot_dataset), ids_test, probabilities, datapath, 'results/test_perturbed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why is this not working?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _print_histogram(nclass, labels_train, labels_min=None, ylim=1700):\n",
    "    \"\"\"Plot the label histogram, or the absolute per-class count difference if labels_min is given.\"\"\"\n",
    "    if labels_train is None:\n",
    "        return\n",
    "    from collections import Counter\n",
    "    hist_train = Counter(labels_train)\n",
    "    if labels_min is not None:\n",
    "        # Counter subtraction keeps only positive counts, so the sum of both\n",
    "        # differences is the absolute difference of the per-class counts\n",
    "        hist_min = Counter(labels_min)\n",
    "        hist_temp = hist_train - hist_min\n",
    "        hist_min = hist_min - hist_train\n",
    "        hist_train = hist_temp + hist_min\n",
    "#         for i in range(self.nclass):\n",
    "#             hist_train.append(np.sum(labels_train == i))\n",
    "    labels, values = zip(*hist_train.items())\n",
    "    indexes = np.asarray(labels)\n",
    "#     miss = set(indexes) - set(labels)\n",
    "#     if len(miss) is not 0:\n",
    "#         hist_train.update({elem:0 for elem in miss})\n",
    "#     labels, values = zip(*hist_train.items())\n",
    "    width = 1\n",
    "    plt.bar(labels, values, width)\n",
    "    plt.title(\"labels distribution\")\n",
    "    plt.ylim(0,ylim)\n",
    "    #plt.xticks(indexes + width * 0.5, labels)\n",
    "    plt.show()"
   ]
  },
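  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels_test must hold the ground-truth test labels; it is not defined earlier in this notebook\n",
    "_print_histogram(55, labels_test)\n",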
    "_print_histogram(55, predictions)\n",
    "_print_histogram(55, labels_test, predictions, ylim=120)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rows are true classes, columns are predicted classes\n",
    "conf_mat = confusion_matrix(labels_test, predictions, labels=list(range(55)))\n",
    "plt.spy(conf_mat, cmap=plt.cm.gist_heat_r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_labels = {0: 'airplane',\n",
    "               1: 'trashcan',\n",
    "               2: 'bag',\n",
    "               3: 'basket',\n",
    "               4: 'bathtub',\n",
    "               5: 'bed',\n",
    "               6: 'bench',\n",
    "               7: 'birdhouse',\n",
    "               8: 'bookshelf',\n",
    "               9: 'bottle',\n",
    "               10: 'bowl',\n",
    "               11: 'bus',\n",
    "               12: 'cabinet',\n",
    "               13: 'camera',\n",
    "               14: 'can',\n",
    "               15: 'cap',\n",
    "               16: 'car',\n",
    "               17: 'cellphone',\n",
    "               18: 'chair',\n",
    "               19: 'clock',\n",
    "               20: 'keyboard',\n",
    "               21: 'dishwasher',\n",
    "               22: 'display',\n",
    "               23: 'earphone',\n",
    "               24: 'faucet',\n",
    "               25: 'file cabinet',\n",
    "               26: 'guitar',\n",
    "               27: 'helmet',\n",
    "               28: 'jar',\n",
    "               29: 'knife',\n",
    "               30: 'lamp',\n",
    "               31: 'laptop',\n",
    "               32: 'speaker',\n",
    "               33: 'mailbox',\n",
    "               34: 'microphone',\n",
    "               35: 'microwave',\n",
    "               36: 'motorcycle',\n",
    "               37: 'mug',\n",
    "               38: 'piano',\n",
    "               39: 'pillow',\n",
    "               40: 'pistol',\n",
    "               41: 'flowerpot',\n",
    "               42: 'printer',\n",
    "               43: 'remote control',\n",
    "               44: 'rifle',\n",
    "               45: 'rocket',\n",
    "               46: 'skateboard',\n",
    "               47: 'sofa',\n",
    "               48: 'stove',\n",
    "               49: 'table',\n",
    "               50: 'telephone',\n",
    "               51: 'tower',\n",
    "               52: 'train',\n",
    "               53: 'vessel',\n",
    "               54: 'washer'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 7 true classes with the most misclassified samples (off-diagonal row sums)\n",
    "lab1 = np.argsort((conf_mat-np.diag(np.diag(conf_mat))).sum(axis=1))[::-1][:7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 7 classes most often predicted wrongly (off-diagonal column sums)\n",
    "lab2 = np.argsort((conf_mat-np.diag(np.diag(conf_mat))).sum(axis=0))[::-1][:7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[dict_labels[lab] for lab in lab1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[dict_labels[lab] for lab in lab2]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
